// Copyright (c) 2020-present, INSPUR Co, Ltd. All rights reserved.
// This source code is licensed under Apache 2.0 License.

#include <assert.h>
#include <algorithm>
#include <functional>
#include <iostream>
#include <limits>
#include <tuple>
#include "art_tree.h"
#include "pure_mem/epoche.h"

namespace art_rowex {
void Tree::insertOrGetLG(const Key &k, TID tid, TID &curTID, TID &nextTID) {

restart:
  bool needRestart = false;

  N *node = nullptr;
  N *nextNode = root;
  N *parentNode = nullptr;
  uint8_t parentKey, nodeKey = 0;
  uint32_t level = 0;

  while (true) {
    parentNode = node;
    parentKey = nodeKey;
    node = nextNode;
    auto v = node->getVersion();

    uint32_t nextLevel = level;
    uint8_t nonMatchingKey;
    Prefix remainingPrefix;
    switch (checkPrefixMatch(node, k, nextLevel, nonMatchingKey,
                             remainingPrefix, this->loadKey_)) {
      case CheckPrefixPessimisticResult::SkippedLevel:
        goto restart;
      case CheckPrefixPessimisticResult::NoMatch: {  
        // prefix not match, or key length is not enough
        assert(nextLevel <= k.getKeyLen());  
        // nextLevel can larger than key. for use 0 position of nextlevel to store key.
        // for example, input key A and AB, level 0 stores index A, how index A and AB?
        // we use level 1 store two index: 0 and B, use 0 to index key A, and B index key AB.
        node->lockVersionOrRestart(v, needRestart);
        if (needRestart) goto restart;

        // 1) Create new node which will be parent of node, Set common prefix,
        // level to this node
        Prefix prefi = node->getPrefi();
        prefi.prefixCount = nextLevel - level;
        auto newNode = new N4(nextLevel, prefi);

        // 2)  add node and (tid, *k) as children
        newNode->insert(nextLevel < k.getKeyLen() ? k[nextLevel] : 0,
                        N::setLeaf(tid));
        newNode->insert(nonMatchingKey, node);

        // 3) lockVersionOrRestart, update parentNode to point to the new node,
        // unlock
        parentNode->writeLockOrRestart(needRestart);
        if (needRestart) {
          delete newNode;
          node->writeUnlock();
          goto restart;
        }
        N::change(parentNode, parentKey, newNode);
        parentNode->writeUnlock();

        // 4) update prefix of node, unlock
        node->setPrefix(remainingPrefix.prefix, node->getPrefi().prefixCount -
                                                    ((nextLevel - level) + 1));

        node->writeUnlock();

        curTID = tid;
        if (nextLevel >= k.getKeyLen() || k[nextLevel] < nonMatchingKey) {
          nextTID = getMinTID(node, 0);
        } else {
          nextTID = getMaxTID(node, 255);
        }

        return;
      }
      case CheckPrefixPessimisticResult::Match: {
        // also process keys which is empty or null
        if (nextLevel == k.getKeyLen()) {
          // key is totally matched. token[0] in next level store TID of current key 
          nodeKey = 0;
          Prefix p = node->getPrefi();
          if (p.prefixCount > 0 && level == nextLevel) {
            // prefix that small than key length is compeltely match, 
            // but key length is not enough match all prefix.
            node->lockVersionOrRestart(v, needRestart);
            if (needRestart) goto restart;
            Prefix pp;
            auto newNode = new N4(level, pp);
            newNode->insert(nodeKey, N::setLeaf(tid));
            newNode->insert(p.prefix[0], node);

            parentNode->writeLockOrRestart(needRestart);
            if (needRestart) {
              delete newNode;
              node->writeUnlock();
              goto restart;
            }
            N::change(parentNode, parentKey, newNode);
            parentNode->writeUnlock();
            node->setPrefix(&p.prefix[1], p.prefixCount - 1);
            node->writeUnlock();

            // TID of current key is already created. so no need search TID of larger key.
            curTID = tid;
            nextTID = getMinTID(node, 0);
            return;
          }

          nextNode = N::getChild(nodeKey, node);
          assert( nextNode == nullptr || N::isLeaf(nextNode));
          if (N::isLeaf(nextNode)) {
            // TID of current key is already created. so no need search TID of larger key.
            curTID = N::getLeaf(nextNode);
            nextTID = 0;
            return;
          } else if (nextNode == nullptr) {
            node->lockVersionOrRestart(v, needRestart);
            if (needRestart) goto restart;
            N::insertAndUnlock(node, parentNode, parentKey, nodeKey,
                               N::setLeaf(tid), needRestart);
            if (needRestart) goto restart;

            curTID = tid;
            // find min TID of current node as larger key.
            nextTID = getMinTID(node, nodeKey + 1);
            if (nextTID == 0) {
              nextTID = getMaxTID(node, nodeKey - 1);
            }
            return;
          }
        }
        break;
      }
    }

    level = nextLevel;
    nodeKey = k[level];
    nextNode = N::getChild(nodeKey, node);

    if (nextNode == nullptr) {
      node->lockVersionOrRestart(v, needRestart);

      if (needRestart) goto restart;
      N::insertAndUnlock(node, parentNode, parentKey, nodeKey, N::setLeaf(tid),
                         needRestart);

      if (needRestart) goto restart;

      curTID = tid;
      nextTID = getMinTID(node, nodeKey + 1);
      if (nextTID == 0) {
        nextTID = getMaxTID(node, nodeKey - 1);
      }
      return;
    }

    if (N::isLeaf(nextNode)) {
      level++;

      node->lockVersionOrRestart(v, needRestart);
      if (needRestart) goto restart;

      Key key;
      loadKey_(N::getLeaf(nextNode), key);

      uint32_t prefixLength = 0;
      while ((level + prefixLength) <
                 std::min(key.getKeyLen(), k.getKeyLen()) &&
             key[level + prefixLength] == k[level + prefixLength]) {
        prefixLength++;
      }

      if ((level + prefixLength) == k.getKeyLen() &&
          k.getKeyLen() == key.getKeyLen()) {  // find TID of current key
        node->writeUnlock();
        curTID = N::getLeaf(nextNode);
        nextTID = 0;  // TID of current key  exists,so no need return larger one. 
        return;
      }

      auto n4 =
          new N4(level + prefixLength,
                 k.getKeyLen() > level ? &k[level] : nullptr, prefixLength);
      n4->insert(
          (level + prefixLength) >= k.getKeyLen() ? 0 : k[level + prefixLength],
          N::setLeaf(tid));
      n4->insert((level + prefixLength) >= key.getKeyLen()
                     ? 0
                     : key[level + prefixLength],
                 nextNode);

      N::change(node, k[level - 1], n4);
      node->writeUnlock();

      curTID = tid;
      nextTID = N::getLeaf(nextNode);  // nextNode is nearest by current key.
      return;
    }
    level++;
  }
}

// Returns the first TID found in the subtree rooted at `node`. At this level
// only children whose key byte is >= `start` are visited; deeper levels are
// visited from byte 0. A leaf yields its own TID; 0 means `node` is null or no
// leaf was found. Assumes N::getChildren reports children in ascending
// key-byte order, which makes the result the smallest key — TODO confirm.
TID Tree::getMinTID(N *node, uint8_t start) const {
  if (node == nullptr) {
    return 0;
  }
  if (N::isLeaf(node)) {
    return N::getLeaf(node);
  }

  std::tuple<uint8_t, N *> kids[256];
  uint32_t kidCount = 0;
  N::getChildren(node, start, 255, kids, kidCount);

  uint32_t idx = 0;
  while (idx < kidCount) {
    const TID found = getMinTID(std::get<1>(kids[idx]), 0);
    if (found != 0) {
      return found;
    }
    ++idx;
  }
  return 0;
}

// Mirror of getMinTID: returns the first TID found when walking the subtree
// rooted at `node` from the highest child byte downwards. At this level only
// children whose key byte is <= `end` are visited; deeper levels start from
// byte 255. A leaf yields its own TID; 0 means `node` is null or no leaf was
// found. Assumes N::getChildren reports children in ascending key-byte order,
// which makes the result the largest key — TODO confirm.
TID Tree::getMaxTID(N *node, uint8_t end) const {
  if (node == nullptr) {
    return 0;
  }
  if (N::isLeaf(node)) {
    return N::getLeaf(node);
  }

  std::tuple<uint8_t, N *> kids[256];
  uint32_t kidCount = 0;
  N::getChildren(node, 0, end, kids, kidCount);

  // Visit the reported children from last to first.
  for (uint32_t idx = kidCount; idx > 0; --idx) {
    const TID found = getMaxTID(std::get<1>(kids[idx - 1]), 255);
    if (found != 0) {
      return found;
    }
  }
  return 0;
}

// Searches the whole tree, starting at the root with level 0, for a TID whose
// key is accepted relative to `k`: `equal` allows a key equal to `k`,
// `greater` allows keys after `k`. Returns 0 when nothing is found.
TID Tree::seek(const Key &k, bool equal, bool greater) const {
  return seekChildren(root, k, 0, equal, greater);
}

// Recursive worker for seek(). Searches the subtree rooted at `node`, whose
// path from the root is assumed to match the first `level` bytes of `k`, for a
// TID accepted by the flags: `equal` allows a key equal to `k` (and, in the
// recursion, signals that the path so far still equals `k`), `greater` allows
// keys after `k`. Returns 0 when no such TID is found.
TID Tree::seekChildren(N *node, const Key &k, uint32_t level, bool equal,
                       bool greater) const {
  if (node == nullptr) {
    return 0;
  }

  if (N::isLeaf(node)) {
    // Either the path matched k completely, or the rest of the key was path
    // compressed into this leaf.
    TID tid = N::getLeaf(node);
    // The path already covers all of k, so this leaf is k or extends it.
    if (equal && greater && k.getKeyLen() <= level) return tid;

    // Otherwise compare the complete stored key with k.
    int compp = compareKeyPrefix(tid, k, std::numeric_limits<int>::max());
    if (compp < 0 || (compp > 0 && !greater) || (compp == 0 && !equal))
      return 0;

    return tid;
  }

  switch (comparePrefixOnly(node, k, level)) {
    case -1:
      // Node prefix < the corresponding bytes of k: every key below is < k.
      return 0;
    case 1:
      // Node prefix > k (or k ends inside it): every key below is > k, so the
      // subtree minimum is the answer.
      if (!greater) return 0;
      return getMinTID(node, 0);
    case 2: {
      // Only the bytes stored inline in the prefix were checked; verify the
      // complete prefix against a real key from this subtree.
      TID min = getMinTID(node, 0);
      if (min == 0) return 0;  // no leaf reachable below this node
      int comp = compareKeyPrefix(min, k, level + node->getPrefi().prefixCount);
      if (comp < 0) return 0;                    // whole subtree < k
      if (comp == 1) return greater ? min : 0;   // whole subtree > k
      // comp == 0 or comp == 2: compareKeyPrefix found the complete prefix
      // equal to k, so continue below the prefix exactly like case 0.
      level += node->getPrefi().prefixCount;
      break;
    }
    case 0:
      level += node->getPrefi().prefixCount;
      break;
  }

  // The prefix matched completely; pick the children to visit.
  uint8_t start = 0;
  uint8_t end = 255;
  if (k.getKeyLen() > level) {
    start = k[level];
  }
  if (!equal) {
    // NOTE(review): this skips the child under k[level] entirely, so keys
    // that extend k are never visited when equal == false. Confirm this is
    // the intended "strictly greater" behaviour for callers.
    if (start == 255) return 0;
    start += 1;
  }
  if (!greater) {
    end = start;
  }

  std::tuple<uint8_t, N *> children[256];
  uint32_t childCount = 0;
  N::getChildren(node, start, end, children, childCount);

  for (uint32_t i = 0; i < childCount; i++) {
    // The path still equals k if k is exhausted or the child byte matches.
    bool pathMatch = equal && (k.getKeyLen() <= level ||
                               std::get<0>(children[i]) == k[level]);
    // Past k on this byte: any key below is greater, take the minimum.
    bool justGetMin = !pathMatch && greater;

    TID ret;
    if (justGetMin)
      ret = getMinTID(std::get<1>(children[i]), 0);
    else
      ret = seekChildren(std::get<1>(children[i]), k, level + 1, pathMatch,
                         greater);
    if (ret != 0) {
      return ret;
    }
  }

  return 0;
}

// Compares the prefix bytes stored inline in node `n` with `k`, starting at
// byte position `level` of `k`.
// Returns:
//    1  a prefix byte is greater than k's byte, or k ends inside the prefix;
//   -1  a prefix byte is smaller than k's byte;
//    2  all inline bytes match, but prefixCount exceeds maxStoredPrefixLength,
//       so the rest of the prefix could not be checked here;
//    0  the whole prefix matches (or the node has no prefix).
int Tree::comparePrefixOnly(N *n, const Key &k, uint32_t level) {
  Prefix p = n->getPrefi();
  if (p.prefixCount == 0) {
    return 0;
  }

  const uint32_t inlineBytes = std::min(p.prefixCount, maxStoredPrefixLength);
  for (uint32_t idx = 0; idx < inlineBytes; ++idx) {
    const uint32_t pos = level + idx;
    if (pos >= k.getKeyLen() || p.prefix[idx] > k[pos]) {
      return 1;
    }
    if (p.prefix[idx] < k[pos]) {
      return -1;
    }
  }

  return p.prefixCount > maxStoredPrefixLength ? 2 : 0;
}

int Tree::compareKeyPrefix(const TID tid, const Key &k, int length) const {
  Key kt;
  this->loadKey_(tid, kt);

  int keyLength = std::min(k.getKeyLen(), kt.getKeyLen());
  int comLength = std::min(keyLength, length);
  for (int i = 0; i < comLength; i++) {
    if (kt[i] > k[i]) {
      return 1;
    } else if (kt[i] < k[i])
      return -1;
  }
  if (comLength == length) {
    if (k.getKeyLen() == kt.getKeyLen()) {
      for (uint32_t i = comLength; i < k.getKeyLen(); i++) {
        if (kt[i] != k[i]) return 0;
      }
    }
    return 2;  // two key is  same.
  }
  if (kt.getKeyLen() > k.getKeyLen()) return 1;
  if (kt.getKeyLen() < k.getKeyLen()) return -1;
  return 0;
}

// Compares the compressed prefix of node `n` with `k`, starting at byte
// position `level` of `k`.
//
// Returns:
//   SkippedLevel - prefixCount + level is below n->getLevel(); the caller
//                  restarts (presumably the prefix changed concurrently).
//   NoMatch      - the prefix differs from k, or k ends inside it. `level` is
//                  the position of that byte, `nonMatchingKey` is the node's
//                  prefix byte there, and `nonMatchingPrefix` receives the
//                  prefix bytes after it (at most maxStoredPrefixLength).
//   Match        - the prefix matched; `level` has advanced past the compared
//                  prefix bytes.
//
// Prefix bytes beyond maxStoredPrefixLength are not stored in the node; they
// are read from the full key of an arbitrary child leaf via `loadKey`.
Tree::CheckPrefixPessimisticResult Tree::checkPrefixMatch(
    N *n, const Key &k, uint32_t &level, uint8_t &nonMatchingKey,
    Prefix &nonMatchingPrefix, LoadKeyFunction loadKey) {
  Prefix p = n->getPrefi();
  if (p.prefixCount + level < n->getLevel()) {
    return CheckPrefixPessimisticResult::SkippedLevel;
  }
  if (p.prefixCount > 0) {
    uint32_t prevLevel = level;
    Key kt;
    // kt is loaded lazily, at most once. The loop may start past
    // maxStoredPrefixLength (its first index is level + prefixCount -
    // n->getLevel()), so loading only at i == maxStoredPrefixLength could
    // leave kt unloaded when it is read.
    bool ktLoaded = false;
    for (uint32_t i = ((level + p.prefixCount) - n->getLevel());
         i < p.prefixCount; ++i) {
      if (i >= maxStoredPrefixLength && !ktLoaded) {
        loadKey(N::getAnyChildTid(n), kt);
        ktLoaded = true;
      }

      uint8_t curKey = i >= maxStoredPrefixLength ? kt[level] : p.prefix[i];
      if (k.getKeyLen() <= level || curKey != k[level]) {
        nonMatchingKey = curKey;
        if (p.prefixCount > maxStoredPrefixLength) {
          // The remaining prefix may extend past the inline bytes; take it
          // from the full child key.
          if (!ktLoaded) {
            loadKey(N::getAnyChildTid(n), kt);
            ktLoaded = true;
          }
          for (uint32_t j = 0;
               j < std::min((p.prefixCount - (level - prevLevel) - 1),
                            maxStoredPrefixLength);
               ++j) {
            nonMatchingPrefix.prefix[j] = kt[level + j + 1];
          }
        } else {
          // The whole prefix is inline: copy the bytes after position i.
          for (uint32_t j = 0; j < p.prefixCount - i - 1; ++j) {
            nonMatchingPrefix.prefix[j] = p.prefix[i + j + 1];
          }
        }
        return CheckPrefixPessimisticResult::NoMatch;
      }
      ++level;
    }
  }
  return CheckPrefixPessimisticResult::Match;
}

}  // namespace art_rowex